/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.amazonaws.services.glue.marketplace.connector.tpcds

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.types.StructType
import com.teradata.tpcds

import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

/** Generate the requested table data and expose it to Spark row by row.
 *
 * The reader follows the `PartitionReader` contract: `next()` advances to the following record and
 * reports whether one exists, while `get()` returns the current record without advancing, so it
 * may safely be called more than once per record.
 *
 * @param tpcdsInputPartition The partition which each Spark task processes.
 * @see [[TPCDSInputPartition]]
 */
class TPCDSPartitionReader(tpcdsInputPartition: TPCDSInputPartition)
  extends PartitionReader[InternalRow]
{
  protected val scale: Int = tpcdsInputPartition.scale
  protected val table: tpcds.Table = tpcdsInputPartition.table
  protected val parallelism: Int = tpcdsInputPartition.numPartitions
  protected val chunkNumber: Int = tpcdsInputPartition.chunkNumber
  protected val schema: StructType = tpcdsInputPartition.schema

  // Since Spark 3.0.0, toRow (and fromRow) in ExpressionEncoder were replaced with Serializer/Deserializer.
  // The serializer is created once per reader: building it for every row affects the job performance.
  private lazy val serializer: ExpressionEncoder.Serializer[Row] =
    RowEncoder.apply(schema).resolveAndBind().createSerializer()

  // Source of generated rows; subclasses may override it to narrow the range that is read.
  protected val itr: Iterator[util.List[util.List[String]]] =
    TPCDSUtils.generateChunkIterator(table, scale, parallelism, chunkNumber)

  // Record produced by the most recent successful next(); null before the first call and after exhaustion.
  private var currentRow: InternalRow = _

  /** Advance to the next generated record.
   *
   * @return true if a record is now available through `get()`, false once the iterator is exhausted.
   */
  override def next(): Boolean = {
    if (itr.hasNext) {
      // Only the first list of each generated element is emitted.
      currentRow = toInternalRow(itr.next().get(0))
      true
    } else {
      currentRow = null
      false
    }
  }

  /** Return the record selected by the last successful `next()` call.
   *
   * @throws NoSuchElementException if `next()` has not been called yet or has already returned false.
   */
  override def get(): InternalRow = {
    if (currentRow == null) {
      throw new NoSuchElementException("No current record: next() was not called or returned false")
    }
    currentRow
  }

  /** Convert one generated row of strings into an `InternalRow` matching `schema`.
   *
   * Each column value is converted to the Scala type of its schema field; nulls are passed through as-is.
   */
  private def toInternalRow(rawRow: util.List[String]): InternalRow = {
    // Read the field array once per row instead of rebuilding a List for every column.
    val fields = schema.fields
    val rowBuf = new ArrayBuffer[Any](rawRow.size())
    for ((cell, idx) <- rawRow.asScala.zipWithIndex) {
      cell match {
        case null => rowBuf.append(null) // Match on null explicitly: calling isEmpty here would throw an NPE.
        case value => rowBuf.append(TPCDSUtils.convertValueType(value, fields(idx).dataType))
      }
    }
    serializer.apply(Row.fromSeq(rowBuf))
  }

  import java.io.IOException
  /** Nothing to release: the reader holds no external resources. */
  @throws[IOException]
  override def close(): Unit = {}
}

/** Generate only a sub-range of the requested table chunk.
 *
 * Behaves like [[TPCDSPartitionReader]], except that the row iterator is restricted to the
 * elements from index `start` (inclusive) up to index `end` (exclusive) of the generated chunk.
 *
 * Note: the parent constructor still evaluates its own `itr` initializer; the value defined
 * here is the one observed through `itr`.
 *
 * @param tpcdsSingleChunkInputPartition The partition for single chunk which each Spark task processes.
 * @see [[TPCDSSingleChunkInputPartition]]
 * @see [[TPCDSPartitionReader]]
 */
class TPCDSSingleChunkPartitionReader(tpcdsSingleChunkInputPartition: TPCDSSingleChunkInputPartition)
  extends TPCDSPartitionReader(tpcdsSingleChunkInputPartition)
{
  final override protected val itr: Iterator[util.List[util.List[String]]] = {
    // Bounds of the window to read, taken from the partition description.
    val firstIndex = tpcdsSingleChunkInputPartition.start
    val untilIndex = tpcdsSingleChunkInputPartition.end
    // Generate the whole chunk, then keep just the requested window.
    val wholeChunk = TPCDSUtils.generateChunkIterator(table, scale, parallelism, chunkNumber)
    wholeChunk.slice(firstIndex, untilIndex)
  }
}